package com.wsjzzcbq.wsjzcode.compile.java;

import com.wsjzzcbq.wsjzcode.compile.JavaCodeCompiler;
import com.wsjzzcbq.wsjzcode.consts.LangTypeEnum;
import com.wsjzzcbq.wsjzcode.util.JavaCodeUitls;
import org.springframework.stereotype.Component;

import javax.tools.*;
import java.io.ByteArrayOutputStream;
import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URI;
import java.nio.CharBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * JavaCodeMemoryCompiler
 * Compiles Java source code entirely in memory and runs a method of the compiled class.
 *
 * <p>Compiled bytecode is kept in a {@link ThreadLocal}. {@link #run(String, String)} therefore
 * only sees classes produced by {@link #compile(String)} on the <em>same</em> thread. Each
 * {@code compile} discards bytecode from earlier compiles on that thread. Each {@code run}
 * consumes the stored bytecode, so running again requires compiling again.
 *
 * @author wsjz
 * @date 2023/07/11
 */
@Component
public class JavaCodeMemoryCompiler implements JavaCodeCompiler {

    /** Per-thread store of compiled bytecode, keyed by binary class name (e.g. {@code Main$1}). */
    private final ThreadLocal<Map<String, byte[]>> classBytes = ThreadLocal.withInitial(HashMap::new);

    /**
     * Compiles a single Java compilation unit in memory.
     *
     * @param code Java source text; its unit name is derived from {@code JavaCodeUitls.getClassName}
     * @return a map with key {@code "result"} (Boolean). When it is {@code false}, the map also
     *         holds {@code "error"}: the compiler diagnostics, one per line
     * @throws UncheckedIOException if the compiler's file manager cannot be closed
     */
    @Override
    public Map<String, Object> compile(String code) {
        Map<String, Object> map = new HashMap<>(2);
        JavaCompiler javac = ToolProvider.getSystemJavaCompiler();
        if (javac == null) {
            // getSystemJavaCompiler() returns null when the runtime does not ship javac (plain JRE).
            map.put("result", false);
            map.put("error", "No system Java compiler available; run the application on a JDK");
            return map;
        }

        // Drop bytecode left over from an earlier compile on this (possibly pooled) thread, so a
        // later run() can never pick up a stale class or another request's class.
        classBytes.remove();
        Map<String, byte[]> output = classBytes.get();

        String className = JavaCodeUitls.getClassName(code);
        List<JavaFileObject> compilationUnits = new ArrayList<>(1);
        compilationUnits.add(new MemoryInJavaFileObject(
                URI.create(className + LangTypeEnum.Java.getValue()), JavaFileObject.Kind.SOURCE, code));

        DiagnosticCollector<JavaFileObject> diagnosticListener = new DiagnosticCollector<>();
        boolean result;
        // Closing the forwarding manager also closes the wrapped standard file manager.
        try (MemoryJavaFileManager fileManager =
                     new MemoryJavaFileManager(javac.getStandardFileManager(null, null, null), output)) {
            result = javac.getTask(null, fileManager, diagnosticListener, null, null, compilationUnits).call();
        } catch (IOException e) {
            // Only close() can throw here: compilation has finished but cleanup failed.
            throw new UncheckedIOException("Failed to close the compiler file manager", e);
        }

        if (result) {
            map.put("result", true);
            return map;
        }
        StringBuilder sb = new StringBuilder();
        for (Diagnostic<? extends JavaFileObject> diagnostic : diagnosticListener.getDiagnostics()) {
            // Separate diagnostics so consecutive messages do not run together.
            if (sb.length() > 0) {
                sb.append('\n');
            }
            sb.append(diagnostic);
        }
        map.put("result", false);
        map.put("error", sb.toString());
        return map;
    }

    /**
     * Loads a class produced by the last {@link #compile(String)} on this thread, instantiates it
     * through its public no-arg constructor and invokes a public no-arg method on it.
     *
     * @param className  binary name of the compiled class to load
     * @param methodName name of the public no-arg method to invoke
     * @return the value returned by the invoked method ({@code null} for {@code void})
     * @throws ClassNotFoundException    if no such class was compiled on this thread
     * @throws NoSuchMethodException     if the constructor or method does not exist
     * @throws InvocationTargetException if the constructor or method throws
     * @throws IllegalAccessException    if the constructor or method is not accessible
     * @throws InstantiationException    if the class is abstract
     */
    @Override
    public Object run(String className, String methodName) throws ClassNotFoundException, NoSuchMethodException, InvocationTargetException, IllegalAccessException, InstantiationException {
        // Give the loader a private snapshot, then clear the thread-local immediately. Nested and
        // anonymous classes, which are resolved lazily, stay loadable, and no bytecode lingers
        // on a pooled thread.
        MemoryClassLoader memoryClassLoader = new MemoryClassLoader(new HashMap<>(classBytes.get()));
        classBytes.remove();
        Class<?> cls = memoryClassLoader.loadClass(className);
        Object obj = cls.getConstructor().newInstance();
        Method method = cls.getMethod(methodName);
        return method.invoke(obj);
    }

    /**
     * File manager that redirects all compiler output into an in-memory map.
     */
    private static final class MemoryJavaFileManager extends ForwardingJavaFileManager<JavaFileManager> {

        /** Destination for generated bytecode, keyed by binary class name. */
        private final Map<String, byte[]> output;

        MemoryJavaFileManager(JavaFileManager fileManager, Map<String, byte[]> output) {
            super(fileManager);
            this.output = output;
        }

        @Override
        public JavaFileObject getJavaFileForOutput(Location location, String className, JavaFileObject.Kind kind, FileObject sibling) {
            return new ClassOutJavaFileObject(URI.create(className), JavaFileObject.Kind.CLASS, className, output);
        }
    }

    /**
     * Source file object backed by an in-memory string.
     */
    private static final class MemoryInJavaFileObject extends SimpleJavaFileObject {

        private final String code;

        MemoryInJavaFileObject(URI uri, Kind kind, String code) {
            super(uri, kind);
            this.code = code;
        }

        @Override
        public CharSequence getCharContent(boolean ignoreEncodingErrors) {
            return CharBuffer.wrap(code);
        }
    }

    /**
     * Class-file object whose bytes are stored in a map when the compiler closes the stream.
     */
    private static final class ClassOutJavaFileObject extends SimpleJavaFileObject {

        /** Binary name of the class being written. */
        private final String className;
        /** Map that receives the finished bytecode. */
        private final Map<String, byte[]> output;

        ClassOutJavaFileObject(URI uri, Kind kind, String className, Map<String, byte[]> output) {
            super(uri, kind);
            this.className = className;
            this.output = output;
        }

        @Override
        public OutputStream openOutputStream() {
            // Buffer directly in a ByteArrayOutputStream. Wrapping it in a FilterOutputStream
            // would write the class file one byte at a time.
            return new ByteArrayOutputStream() {
                @Override
                public void close() throws IOException {
                    super.close();
                    output.put(className, toByteArray());
                }
            };
        }
    }

    /**
     * Custom ClassLoader that defines classes from an in-memory bytecode map.
     */
    private static final class MemoryClassLoader extends ClassLoader {

        /** Snapshot of compiled bytecode, keyed by binary class name. */
        private final Map<String, byte[]> bytecode;

        MemoryClassLoader(Map<String, byte[]> bytecode) {
            this.bytecode = bytecode;
        }

        /**
         * findClass is protected, so a custom loader has to subclass ClassLoader.
         * ClassLoader.loadClass calls findClass once parent delegation fails to find the class.
         *
         * @param name binary class name
         * @return the defined class
         * @throws ClassNotFoundException if no bytecode was compiled for {@code name}
         */
        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            byte[] buf = bytecode.get(name);
            if (buf == null) {
                throw new ClassNotFoundException(name);
            }
            return defineClass(name, buf, 0, buf.length);
        }
    }
}
